/*
 *    NaiveBayes.java
 *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
package moa.classifiers;

import moa.core.AutoExpandVector;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.StringUtils;
import weka.core.Instance;

/**
 * Incremental Naive Bayes classifier.
 * <p>
 * The model consists of a weighted class distribution plus one
 * {@link AttributeClassObserver} per (non-class) attribute. Observers are
 * created lazily the first time an attribute is trained on: nominal attributes
 * get the observer returned by {@link #newNominalClassObserver()} and numeric
 * attributes get the one returned by {@link #newNumericClassObserver()};
 * subclasses may override either factory.
 */
public class NaiveBayes extends AbstractClassifier {

	private static final long serialVersionUID = 1L;

	@SuppressWarnings("hiding")
	public static final String classifierPurposeString = "Naive Bayes classifier: performs classic bayesian prediction while making naive assumption that all inputs are independent.";

	/** Total training weight observed for each class index. */
	protected DoubleVector observedClassDistribution;

	/**
	 * Per-attribute observers indexed by model attribute index. An entry stays
	 * null until that attribute is first seen during training.
	 */
	protected AutoExpandVector<AttributeClassObserver> attributeObservers;

	/** Discards all learned statistics. */
	@Override
	public void resetLearningImpl() {
		this.observedClassDistribution = new DoubleVector();
		this.attributeObservers = new AutoExpandVector<AttributeClassObserver>();
	}

	/**
	 * Updates the class distribution and every attribute observer with one
	 * weighted training instance.
	 * <p>
	 * Instances whose class value is missing are ignored. Weka encodes missing
	 * values as NaN, and {@code (int) NaN} is 0 in Java, so without this check
	 * such an instance would be silently counted as belonging to class 0.
	 *
	 * @param inst the training instance
	 */
	@Override
	public void trainOnInstanceImpl(Instance inst) {
		if (inst.classIsMissing()) {
			return;
		}
		int classValue = (int) inst.classValue();
		double weight = inst.weight();
		this.observedClassDistribution.addToValue(classValue, weight);
		// The "- 1" leaves out one attribute, presumably the class attribute;
		// modelAttIndexToInstanceAttIndex maps model to instance positions.
		for (int i = 0; i < inst.numAttributes() - 1; i++) {
			int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
			AttributeClassObserver obs = this.attributeObservers.get(i);
			if (obs == null) {
				obs = inst.attribute(instAttIndex).isNominal() ? newNominalClassObserver()
						: newNumericClassObserver();
				this.attributeObservers.set(i, obs);
			}
			obs.observeAttributeClass(inst.value(instAttIndex), classValue,
					weight);
		}
	}

	/**
	 * Returns the (unnormalized) class votes for an instance, computed by
	 * {@link #doNaiveBayesPrediction}.
	 *
	 * @param inst the instance to classify
	 * @return one vote per observed class index
	 */
	public double[] getVotesForInstance(Instance inst) {
		return doNaiveBayesPrediction(inst, this.observedClassDistribution,
				this.attributeObservers);
	}

	/**
	 * This classifier reports no model-specific measurements.
	 *
	 * @return always {@code null}
	 */
	@Override
	protected Measurement[] getModelMeasurementsImpl() {
		return null;
	}

	/**
	 * Appends a human-readable description of the model: for each class, its
	 * total observed weight and prior, followed by each attribute observer.
	 *
	 * @param out    destination buffer
	 * @param indent base indentation level
	 */
	@Override
	public void getModelDescription(StringBuilder out, int indent) {
		// Loop-invariant: the total weight does not change while describing.
		double totalWeight = this.observedClassDistribution.sumOfValues();
		for (int i = 0; i < this.observedClassDistribution.numValues(); i++) {
			StringUtils.appendIndented(out, indent, "Observations for ");
			out.append(getClassNameString());
			out.append(" = ");
			out.append(getClassLabelString(i));
			out.append(":");
			StringUtils.appendNewlineIndented(out, indent + 1,
					"Total observed weight = ");
			out.append(this.observedClassDistribution.getValue(i));
			out.append(" / prob = ");
			out.append(this.observedClassDistribution.getValue(i) / totalWeight);
			for (int j = 0; j < this.attributeObservers.size(); j++) {
				StringUtils.appendNewlineIndented(out, indent + 1,
						"Observations for ");
				out.append(getAttributeNameString(j));
				out.append(": ");
				// TODO: implement observer output
				out.append(this.attributeObservers.get(j));
			}
			StringUtils.appendNewline(out);
		}
	}

	/** @return {@code false}; training and prediction use no randomness. */
	public boolean isRandomizable() {
		return false;
	}

	/**
	 * Factory for observers of nominal attributes.
	 *
	 * @return a new {@link NominalAttributeClassObserver}
	 */
	protected AttributeClassObserver newNominalClassObserver() {
		return new NominalAttributeClassObserver();
	}

	/**
	 * Factory for observers of numeric attributes.
	 *
	 * @return a new {@link GaussianNumericAttributeClassObserver}
	 */
	protected AttributeClassObserver newNumericClassObserver() {
		return new GaussianNumericAttributeClassObserver();
	}

	/**
	 * Computes Naive Bayes class votes for an instance.
	 * <p>
	 * Each class's vote is its prior (class weight / total weight) multiplied
	 * by {@code probabilityOfAttributeValueGivenClass} for every attribute that
	 * has an observer and a non-missing value in {@code inst}. The votes are
	 * not normalized.
	 * <p>
	 * If no class weight has been observed (total weight is 0), all-zero votes
	 * are returned instead of the NaN values that {@code 0 / 0} would produce.
	 * <p>
	 * NOTE: the product of many small factors can underflow to 0. This is not
	 * currently prevented.
	 *
	 * @param inst                      the instance to classify
	 * @param observedClassDistribution weight observed per class index
	 * @param attributeObservers        per-attribute observers (entries may be null)
	 * @return one vote per class index in {@code observedClassDistribution}
	 */
	public static double[] doNaiveBayesPrediction(Instance inst,
			DoubleVector observedClassDistribution,
			AutoExpandVector<AttributeClassObserver> attributeObservers) {
		double[] votes = new double[observedClassDistribution.numValues()];
		double observedClassSum = observedClassDistribution.sumOfValues();
		if (observedClassSum == 0.0) {
			// Nothing learned yet: every prior would be 0/0 = NaN.
			return votes;
		}
		for (int classIndex = 0; classIndex < votes.length; classIndex++) {
			votes[classIndex] = observedClassDistribution.getValue(classIndex)
					/ observedClassSum;
		}
		// Attribute-major order keeps the per-attribute work (index mapping,
		// observer lookup, missing check) out of the per-class loop. Each class
		// still multiplies its factors in the same attribute order, so the
		// floating-point results are identical to a class-major loop.
		for (int attIndex = 0; attIndex < inst.numAttributes() - 1; attIndex++) {
			int instAttIndex = modelAttIndexToInstanceAttIndex(attIndex, inst);
			AttributeClassObserver obs = attributeObservers.get(attIndex);
			if ((obs == null) || inst.isMissing(instAttIndex)) {
				continue;
			}
			double attValue = inst.value(instAttIndex);
			for (int classIndex = 0; classIndex < votes.length; classIndex++) {
				votes[classIndex] *= obs.probabilityOfAttributeValueGivenClass(
						attValue, classIndex);
			}
		}
		return votes;
	}

	/**
	 * No-op: memory management is not implemented for this classifier.
	 *
	 * @param currentByteSize current model size in bytes (unused)
	 * @param maxByteSize     memory limit in bytes (unused)
	 */
	public void manageMemory(int currentByteSize, int maxByteSize) {
		// Intentionally empty.
	}

}
